package ci.web.codec;


import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.HashSet;

/**
 * Checks that the Host header of the first HTTP request on a connection belongs to one of the
 * configured allowed domains.<br/>
 * Guards against third-party domain names being pointed at this server.
 * <p>
 * The first {@link HttpRequest} on a channel is validated before it is passed on. If it is
 * allowed, it is forwarded and this handler removes itself from the pipeline, so later requests
 * on the same connection are not re-checked. If it is rejected, a 403 is written, the channel is
 * closed, and the request plus any further inbound messages on that channel are released instead
 * of being forwarded.
 * @author zhh
 */
@ChannelHandler.Sharable
public class HostHandler extends ChannelDuplexHandler {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(HostHandler.class);

    /**
     * Channel-scoped marker set once a request has been rejected. The handler is shared between
     * channels, so per-connection state must live on the channel rather than in a field.
     */
    private static final AttributeKey<Boolean> REJECTED =
            AttributeKey.valueOf(HostHandler.class, "REJECTED");

    /**
     * Whether subdomains of a configured host are also allowed. When true, configured
     * "example.com" also admits "a.example.com". This is read on every check, so changing it
     * affects handlers that already exist.
     */
    public static boolean AlowSameRootDomain = true;

    /**
     * Builds a handler from a host whitelist configuration.
     * @param hostsConfig hosts separated by ';' or ','. Each entry is normalized by
     *        {@link #formatHost(String)}, and empty entries are ignored. An entry of "*" allows every host.
     * @return the handler, or null when no check is needed (the config is null or blank,
     *         contains no usable entry, or contains "*")
     */
    public static HostHandler make(String hostsConfig){
        if(hostsConfig==null || hostsConfig.trim().isEmpty()){
            return null;
        }
        HashSet<String> set = new HashSet<String>();
        for(String entry : hostsConfig.split("[;,]")){
            String n = formatHost(entry.trim());
            if(n.length()>0){
                set.add(n);
            }
        }
        // "*" is a wildcard for every host, whether it is the only entry or mixed with others.
        if(set.isEmpty() || set.contains("*")){
            return null;
        }
        String[] arr = set.toArray(new String[set.size()]);
        if(arr.length==1){
            return new HostHandler(arr[0]);
        }
        return new HostHandler(arr);
    }

    /** Single whitelisted host, or null when {@link #hosts} is used. */
    protected final String host;
    /** Whitelisted hosts, or null when {@link #host} is used. */
    protected final String[] hosts;

    /**
     * @param host the single whitelisted host
     */
    protected HostHandler(String host){
        this.host = host;
        this.hosts = null;
    }

    /**
     * @param hosts the whitelisted hosts
     */
    protected HostHandler(String[] hosts){
        this.host = null;
        this.hosts = hosts;
    }

    /**
     * Validates the first HTTP request before forwarding it.
     * An allowed request is forwarded and this handler is removed from the pipeline.
     * A rejected request is answered with 403 by {@link #checkAlow}, released, and the
     * channel is marked so that any further inbound messages are dropped rather than forwarded.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
        if(Boolean.TRUE.equals(ctx.channel().attr(REJECTED).get())){
            // A 403 has already been written and the channel is closing; e.g. body chunks of
            // the rejected request must not reach the application handlers.
            ReferenceCountUtil.release(msg);
            return;
        }
        if(!(msg instanceof HttpRequest)){
            ctx.fireChannelRead(msg);
            return;
        }
        boolean allowed = false;
        try{
            // Check BEFORE forwarding, so a request with a bad Host never reaches later handlers.
            allowed = checkAlow(ctx, (HttpRequest)msg);
        }finally{
            if(!allowed){
                // Also covers checkAlow throwing: never forward, never leak the message.
                ctx.channel().attr(REJECTED).set(Boolean.TRUE);
                ReferenceCountUtil.release(msg);
            }
        }
        ctx.fireChannelRead(msg);
        ctx.pipeline().remove(this);
    }

    /**
     * Checks the request's Host header against the whitelist and writes a 403 when it does not match.
     * @param ctx context used to send the forbidden response
     * @param request the request whose Host header is checked. A missing header is treated as an
     *        empty host and is rejected.
     * @return true when the host is allowed, false when a forbidden response was sent
     */
    protected boolean checkAlow(final ChannelHandlerContext ctx, final HttpRequest request) {
        String target = hostName(request.headers().getAsString(HttpHeaderNames.HOST));
        boolean ret = false;
        if(host!=null){
            ret = equalsHost(host, target, AlowSameRootDomain);
        }else{
            for(String node : hosts){
                if(equalsHost(node, target, AlowSameRootDomain)){
                    ret = true;
                    break;
                }
            }
        }
        if(!ret){
            forbidden(ctx, target, request);
        }
        return ret;
    }

    /**
     * Strips the optional port from a Host header value.
     * Bracketed IPv6 literals such as "[::1]:8080" keep their brackets ("[::1]").
     * @param hostHeader the raw header value, may be null
     * @return the host part, or "" when the header is absent
     */
    private static String hostName(String hostHeader){
        if(hostHeader==null){
            return "";
        }
        String target = hostHeader.trim();
        if(target.startsWith("[")){
            int end = target.indexOf(']');
            return end>0 ? target.substring(0, end+1) : target;
        }
        int idx = target.indexOf(':');
        return idx>0 ? target.substring(0, idx) : target;
    }

    /**
     * Normalizes a configured host entry by stripping a leading "*.", "." or "www.".
     * <p>
     * NOTE: because the prefix is removed, "www.example.com" is reduced to "example.com". With
     * {@link #AlowSameRootDomain} false, such an entry therefore only matches "example.com" exactly.
     * @param host the trimmed entry (may be empty)
     * @return the normalized entry; an empty input is returned unchanged
     */
    protected static String formatHost(String host){
        if(host.isEmpty()){
            return host;
        }
        if(host.startsWith("*.")){
            return host.substring(2);
        }
        if(host.charAt(0)=='.'){
            return host.substring(1);
        }
        if(host.startsWith("www.")){
            return host.substring(4);
        }
        return host;
    }

    /**
     * Compares a whitelisted host with a request host, ignoring case (host names are case-insensitive).
     * @param host the whitelisted host
     * @param target the host taken from the request
     * @param sameDomainAlow when true, any subdomain of {@code host} ("x.host") also matches
     * @return true when {@code target} is allowed by {@code host}
     */
    protected static boolean equalsHost(String host, String target, boolean sameDomainAlow){
        if(host.equalsIgnoreCase(target)){
            return true;
        }
        if(!sameDomainAlow){
            return false;
        }
        int offset = target.length() - host.length();
        // Need at least one label character plus the separating '.' in front of the suffix.
        if(offset<2){
            return false;
        }
        return target.charAt(offset-1)=='.'
                && target.regionMatches(true, offset, host, 0, host.length());
    }

    /**
     * Writes an empty 403 response, closes the channel once it is flushed, and logs the rejected host.
     * @param ctx the channel context
     * @param host the rejected host (logged only)
     * @param request the request being rejected; its protocol version is reused for the response
     */
    protected static void forbidden(final ChannelHandlerContext ctx, String host, final HttpRequest request){
        DefaultFullHttpResponse response =
                new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN);
        // Explicitly empty body so the client does not wait for content before the close.
        response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0);
        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
        logger.info("BadHost:{}@{}", host, remoteIp(ctx));
    }

    /**
     * Describes the remote peer for logging without assuming an inet transport or a resolved address.
     */
    private static String remoteIp(final ChannelHandlerContext ctx){
        SocketAddress addr = ctx.channel().remoteAddress();
        if(addr instanceof InetSocketAddress){
            InetSocketAddress inet = (InetSocketAddress)addr;
            InetAddress ip = inet.getAddress();
            return ip!=null ? ip.getHostAddress() : inet.getHostString();
        }
        return String.valueOf(addr);
    }
}
